{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adjacency Spectral Embed\n",
    "\n",
    "This demo shows how to use the `AdjacencySpectralEmbed` (ASE) class. We then use ASE to show how the two communities of a stochastic block model (SBM) graph can be recovered in the embedded space using k-means."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.metrics import adjusted_rand_score\n",
    "from sklearn.cluster import KMeans\n",
    "\n",
    "from graspologic.embed import AdjacencySpectralEmbed\n",
    "from graspologic.simulations import sbm\n",
    "from graspologic.plot import heatmap, pairplot\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "np.random.seed(8889)\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Generation\n",
    "ASE is a method for estimating the latent positions of a network modeled as a Random Dot Product Graph (RDPG). This embedding is both a form of dimensionality reduction for a graph and a way of fitting a generative model to graph data. We first generate two 2-block SBMs: one directed, and one undirected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define parameters\n",
    "n_verts = 100  # vertices per block; each graph has 2 * n_verts nodes\n",
    "labels_sbm = n_verts * [0] + n_verts * [1]\n",
    "P = np.array([[0.8, 0.2], \n",
    "              [0.2, 0.8]])\n",
    "\n",
    "# Generate SBMs from parameters\n",
    "undirected_sbm = sbm(2 * [n_verts], P)\n",
    "directed_sbm = sbm(2 * [n_verts], P, directed=True)\n",
    "\n",
    "# Plot both SBMs\n",
    "fig, axes = plt.subplots(1, 2, figsize=(16, 8))\n",
    "heatmap(undirected_sbm, title='2-block SBM (undirected)', inner_hier_labels=labels_sbm, ax=axes[0])\n",
    "heatmap(directed_sbm, title='2-block SBM (directed)', inner_hier_labels=labels_sbm, ax=axes[1]);"
   ]
  },
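  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what the embedding computes, the cell below sketches ASE with a plain NumPy SVD: keep the top `d` singular vectors of the adjacency matrix and scale them by the square roots of the corresponding singular values. The helper name `ase_sketch` is hypothetical and not part of graspologic. This is a minimal sketch under those assumptions, not the library's implementation, which may differ in details such as the SVD solver and how the diagonal is handled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def ase_sketch(adjacency, n_components):\n",
    "    # hypothetical helper for illustration only, not a graspologic function\n",
    "    U, S, _ = np.linalg.svd(adjacency)\n",
    "    # keep the leading singular vectors, scaled by the square roots of their singular values\n",
    "    return U[:, :n_components] * np.sqrt(S[:n_components])\n",
    "\n",
    "# undirected_sbm comes from the data generation cell above\n",
    "X_sketch = ase_sketch(undirected_sbm, n_components=2)\n",
    "print(X_sketch.shape)"
   ]
  },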
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Embedding: Undirected Case"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now use the `AdjacencySpectralEmbed` class to embed the adjacency matrix into a lower-dimensional space.  \n",
    "If `n_components` is not given, `AdjacencySpectralEmbed` automatically chooses the number of dimensions to embed into."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# instantiate an ASE object\n",
    "ase = AdjacencySpectralEmbed()\n",
    "\n",
    "# call its fit_transform method to generate latent positions\n",
    "Xhat = ase.fit_transform(undirected_sbm)\n",
    "_ = pairplot(Xhat, title='SBM adjacency spectral embedding')"
   ]
  },
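  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how many dimensions were chosen automatically, one can inspect the shape of the returned array. The cell below is a small sketch that assumes `ase` and `Xhat` from the previous cell, and that the fitted estimator exposes the retained singular values as `singular_values_`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rows are nodes, columns are embedding dimensions\n",
    "n_nodes, n_dims = Xhat.shape\n",
    "print(f'Embedded {n_nodes} nodes into {n_dims} dimensions')\n",
    "\n",
    "# assumption: the fitted ASE object stores the retained singular values here\n",
    "print('Retained singular values:', ase.singular_values_)"
   ]
  },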
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Embedding: Directed Case"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the graph is directed, `fit_transform` returns two embeddings that roughly correspond to the \"out\" and \"in\" latent positions of each node, since these are no longer the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transform in directed case\n",
    "ase = AdjacencySpectralEmbed()\n",
    "Xhat, Yhat = ase.fit_transform(directed_sbm)\n",
    "\n",
    "# Plot both embeddings\n",
    "pairplot(Xhat, title='SBM adjacency spectral embedding \"out\"')\n",
    "_ = pairplot(Yhat, title='SBM adjacency spectral embedding \"in\"')"
   ]
  },
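  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Under the RDPG model, the inner product of one node's \"out\" position with another node's \"in\" position estimates the probability of an edge from the first node to the second. As a sanity check, and assuming `Xhat`, `Yhat`, and `n_verts` from the cells above, the sketch below averages these estimated probabilities over each pair of blocks. The result should land near the matrix `P` used to generate the graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# estimated edge probabilities from the 'out' and 'in' latent positions\n",
    "P_hat = Xhat @ Yhat.T\n",
    "\n",
    "# average the estimates over each of the four block pairs\n",
    "blocks = [slice(0, n_verts), slice(n_verts, 2 * n_verts)]\n",
    "P_hat_blocks = np.array([[P_hat[rows, cols].mean() for cols in blocks]\n",
    "                         for rows in blocks])\n",
    "print(np.round(P_hat_blocks, 2))"
   ]
  },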
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dimension specification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One can also specify the embedding parameters explicitly.  \n",
    "Here, we set the number of embedding dimensions with `n_components` and choose the SVD solver used to compute the embedding with `algorithm`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fit and transform\n",
    "ase = AdjacencySpectralEmbed(n_components=2, algorithm='truncated')\n",
    "Xhat = ase.fit_transform(undirected_sbm)\n",
    "\n",
    "# plot\n",
    "_ = pairplot(Xhat, title='2-component embedding', height=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clustering in the embedded space\n",
    "Now, we will use the Euclidean representation of the graph to apply a standard clustering algorithm like k-means. We start with an SBM in which the 2 blocks have exactly the same connection probabilities (effectively giving us an Erdős-Rényi (ER) graph). In this case, k-means will not be able to distinguish between the two embedded blocks. As the within-block and between-block connection probabilities become more distinct, the clustering will improve. For each graph, we plot its adjacency matrix, the predicted k-means cluster labels in the embedded space, and which nodes were labeled correctly or incorrectly relative to the true labels. Adjusted Rand Index (ARI) is a measure of clustering accuracy, where 1 is perfect clustering relative to ground truth. Error rate is simply the proportion of incorrectly labeled nodes."
   ]
  },
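  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because k-means assigns cluster indices arbitrarily, the error rate for two clusters has to be computed up to a swap of the two labels. A minimal sketch of that idea follows, using a hypothetical helper `two_cluster_error_rate` that is not part of any library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def two_cluster_error_rate(true_labels, predicted_labels):\n",
    "    # hypothetical helper for illustration; only valid for exactly two clusters\n",
    "    true_labels = np.asarray(true_labels)\n",
    "    predicted_labels = np.asarray(predicted_labels)\n",
    "    mismatch = np.mean(true_labels != predicted_labels)\n",
    "    # swapping the two predicted labels turns a mismatch rate m into 1 - m\n",
    "    return float(min(mismatch, 1 - mismatch))\n",
    "\n",
    "# the predicted labels are inverted relative to the truth, with one node misplaced\n",
    "print(two_cluster_error_rate([0, 0, 1, 1], [1, 1, 0, 1]))  # prints 0.25"
   ]
  },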
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "palette = {'Right': (0, 0.7, 0.2),\n",
    "           'Wrong': (0.8, 0.1, 0.1)}\n",
    "true_labels = np.array(labels_sbm)\n",
    "\n",
    "for insularity in np.linspace(0.5, 0.625, 4):\n",
    "    # within-block probability is insularity, between-block is 1 - insularity\n",
    "    P = np.array([[insularity, 1 - insularity], [1 - insularity, insularity]])\n",
    "    sampled_sbm = sbm(2 * [n_verts], P)\n",
    "    Xhat = AdjacencySpectralEmbed(n_components=2).fit_transform(sampled_sbm)\n",
    "    labels_kmeans = KMeans(n_clusters=2).fit_predict(Xhat)\n",
    "    ari = adjusted_rand_score(labels_sbm, labels_kmeans)\n",
    "    # boolean mask of nodes whose predicted label differs from the true label\n",
    "    error = true_labels != labels_kmeans\n",
    "    # k-means cluster indices are arbitrary, so its labels may be the inverse of ours\n",
    "    if np.mean(error) > 0.5:\n",
    "        error = ~error\n",
    "    error_rate = np.mean(error)\n",
    "    error_label = np.where(error, 'Wrong', 'Right')\n",
    "\n",
    "    # fixed-point formatting avoids truncating scientific notation for values near zero\n",
    "    heatmap(sampled_sbm, title=f'Insularity: {insularity:.3f}',\n",
    "            inner_hier_labels=labels_sbm)\n",
    "    pairplot(Xhat,\n",
    "             labels=labels_kmeans,\n",
    "             title=f'KMeans on embedding, ARI: {ari:.3f}',\n",
    "             legend_name='Predicted label',\n",
    "             height=3.5,\n",
    "             palette='muted')\n",
    "    pairplot(Xhat,\n",
    "             labels=error_label,\n",
    "             title=f'Error from KMeans, Error rate: {error_rate:.3f}',\n",
    "             legend_name='Error label',\n",
    "             height=3.5,\n",
    "             palette=palette)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}